#创建模型
import torchvision.models as models
import torch.nn as nn
from torchvision.models import ResNet50_Weights
def get_model(num_classes=10):
    # 加载预训练的 ResNet50 模型，使用新的weights参数替换pretrained=True
    model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
    # 修改最后的全连接层以适应 10 类分类任务
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model